{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mixing tensorflow models with gpflow\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T10:53:12.620472Z",
     "start_time": "2018-06-20T10:53:11.541346Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from matplotlib import pyplot as plt\n",
    "import gpflow\n",
    "from gpflow.test_util import notebook_niter, is_continuous_integration\n",
    "from scipy.cluster.vq import kmeans2\n",
    "\n",
    "float_type = gpflow.settings.float_type \n",
    "\n",
    "ITERATIONS = notebook_niter(1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 1: a convnet inside a gpflow model\n",
    "Here we'll use the gpflow functionality, but we'll put a non-gpflow model inside the kernel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T10:53:16.316761Z",
     "start_time": "2018-06-20T10:53:12.621686Z"
    }
   },
   "outputs": [],
   "source": [
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "mnist = input_data.read_data_sets(\"./data/MNIST_data/\", one_hot=False)\n",
    "\n",
    "class Mnist:\n",
    "    input_dim = 784\n",
    "    Nclasses = 10\n",
    "    X = mnist.train.images.astype(float)\n",
    "    Y = mnist.train.labels.astype(float)[:, None]\n",
    "    Xtest = mnist.test.images.astype(float)\n",
    "    Ytest = mnist.test.labels.astype(float)[:, None]\n",
    "\n",
    "if is_continuous_integration():\n",
    "    mask = (Mnist.Y <= 1).squeeze()\n",
    "    Mnist.X = Mnist.X[mask][:105, 300:305]\n",
    "    Mnist.Y = Mnist.Y[mask][:105]\n",
    "    mask = (Mnist.Ytest <= 1).squeeze()\n",
    "    Mnist.Xtest = Mnist.Xtest[mask][:10, 300:305]\n",
    "    Mnist.Ytest = Mnist.Ytest[mask][:10]\n",
    "    Mnist.input_dim = 5\n",
    "    Mnist.Nclasses = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T10:53:16.335949Z",
     "start_time": "2018-06-20T10:53:16.317967Z"
    }
   },
   "outputs": [],
   "source": [
    "# a vanilla conv net\n",
    "# this gets 97.3% accuracy on MNIST when used on its own (+ final linear layer) after 20K iterations\n",
    "def cnn_fn(x, output_dim):\n",
    "    \"\"\"\n",
    "    Adapted from https://www.tensorflow.org/tutorials/layers\n",
    "    \"\"\"\n",
    "    conv1 = tf.layers.conv2d(\n",
    "          inputs=tf.reshape(x, [-1, 28, 28, 1]),\n",
    "          filters=32,\n",
    "          kernel_size=[5, 5],\n",
    "          padding=\"same\",\n",
    "          activation=tf.nn.relu)\n",
    "\n",
    "    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)\n",
    "\n",
    "    conv2 = tf.layers.conv2d(\n",
    "          inputs=pool1,\n",
    "          filters=64,\n",
    "          kernel_size=[5, 5],\n",
    "          padding=\"same\",\n",
    "          activation=tf.nn.relu)\n",
    "    \n",
    "    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)\n",
    "\n",
    "    pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])\n",
    "    return tf.layers.dense(inputs=pool2_flat, units=output_dim, activation=tf.nn.relu)\n",
    "\n",
    "if is_continuous_integration():\n",
    "    def cnn_fn(x, output_dim):\n",
    "        return tf.layers.dense(inputs=tf.reshape(x, [-1, Mnist.input_dim]), units=output_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T10:53:16.475537Z",
     "start_time": "2018-06-20T10:53:16.337305Z"
    }
   },
   "outputs": [],
   "source": [
    "class KernelWithNN(gpflow.kernels.Kernel):\n",
    "    \"\"\"\n",
    "    This kernel class allows for easily adding a NN (or other function) to a GP model.\n",
    "    The kernel does not actually do anything with the NN.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, kern, f):\n",
    "        \"\"\"\n",
    "        kern.input_dim needs to be consistent with the output dimension of f\n",
    "        \"\"\"\n",
    "        super().__init__(kern.input_dim)\n",
    "        self.kern = kern\n",
    "        self._f = f\n",
    "        \n",
    "    def f(self, X):\n",
    "        if X is not None:\n",
    "            with tf.variable_scope('forward', reuse=tf.AUTO_REUSE):\n",
    "                return self._f(X)\n",
    "    \n",
    "    def _get_f_vars(self):\n",
    "        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='forward')\n",
    "\n",
    "    @gpflow.autoflow([gpflow.settings.float_type, [None,None]])\n",
    "    def compute_f(self, X):\n",
    "        return self.f(X)\n",
    "    \n",
    "    def K(self, X, X2=None):\n",
    "        return self.kern.K(X, X2)\n",
    "    \n",
    "    def Kdiag(self, X):\n",
    "        return self.kern.Kdiag(X)\n",
    "\n",
    "class KernelSpaceInducingPoints(gpflow.features.InducingPointsBase):\n",
    "    pass\n",
    "\n",
    "# same Kuu as regular inducing points\n",
    "gpflow.features.Kuu.register(KernelSpaceInducingPoints, KernelWithNN)(\n",
    "    gpflow.features.Kuu.dispatch(gpflow.features.InducingPoints, gpflow.kernels.Kernel)\n",
    ")\n",
    "\n",
    "# Kuf is in NN output space\n",
    "@gpflow.features.dispatch(KernelSpaceInducingPoints, KernelWithNN, object)\n",
    "def Kuf(feat, kern, Xnew):\n",
    "    with gpflow.params_as_tensors_for(feat):\n",
    "        return kern.K(feat.Z, kern.f(Xnew))\n",
    "\n",
    "class NNComposedKernel(KernelWithNN):\n",
    "    \"\"\"\n",
    "    This kernel class applies f() to X before calculating K\n",
    "    \"\"\"\n",
    "    \n",
    "    def K(self, X, X2=None):\n",
    "        return super().K(self.f(X), self.f(X2))\n",
    "    \n",
    "    def Kdiag(self, X):\n",
    "        return super().Kdiag(self.f(X))\n",
    "    \n",
    "# we need to add these extra functions to the model so the tensorflow variables get picked up\n",
    "class NN_SVGP(gpflow.models.SVGP):\n",
    "    @property\n",
    "    def trainable_tensors(self):\n",
    "        return super().trainable_tensors + self.kern._get_f_vars()\n",
    "\n",
    "    @property\n",
    "    def initializables(self):\n",
    "        return super().initializables + self.kern._get_f_vars()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T11:10:46.858524Z",
     "start_time": "2018-06-20T10:53:16.477021Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy is 98.6800%\n"
     ]
    }
   ],
   "source": [
    "def ex1():\n",
    "    fX_dim = 5  \n",
    "    M = 100\n",
    "\n",
    "    # annoyingly only float32 and lower is supported by the conv layers \n",
    "    f = lambda x: tf.cast(cnn_fn(tf.cast(x, tf.float32), fX_dim), float_type)\n",
    "    kern = NNComposedKernel(gpflow.kernels.Matern32(fX_dim), f)\n",
    "\n",
    "    # build the model \n",
    "\n",
    "    lik = gpflow.likelihoods.MultiClass(Mnist.Nclasses)\n",
    "\n",
    "    Z = kmeans2(Mnist.X, M, minit='points')[0]\n",
    "\n",
    "    model = NN_SVGP(Mnist.X, Mnist.Y, kern, lik, Z=Z, num_latent=Mnist.Nclasses, minibatch_size=1000)\n",
    "\n",
    "    # use gpflow wrappers to train. NB all session handling is done for us\n",
    "    gpflow.training.AdamOptimizer(0.001).minimize(model, maxiter=ITERATIONS)\n",
    "\n",
    "    # predictions\n",
    "    m, v = model.predict_y(Mnist.Xtest)\n",
    "    preds = np.argmax(m, 1).reshape(Mnist.Ytest.shape)\n",
    "    correct = preds == Mnist.Ytest.astype(int)\n",
    "    acc = np.average(correct.astype(float)) * 100.\n",
    "\n",
    "    print('accuracy is {:.4f}%'.format(acc))\n",
    "\n",
    "gpflow.reset_default_graph_and_session()\n",
    "ex1()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T11:25:46.165332Z",
     "start_time": "2018-06-20T11:10:46.860361Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy is 97.5600%\n"
     ]
    }
   ],
   "source": [
    "def ex2b():\n",
    "    fX_dim = 5  \n",
    "    minibatch_size = notebook_niter(1000, test_n=10)\n",
    "    M = notebook_niter(100, test_n=5)\n",
    "\n",
    "    # annoyingly only float32 and lower is supported by the conv layers \n",
    "    f = lambda x: tf.cast(cnn_fn(tf.cast(x, tf.float32), fX_dim), float_type)\n",
    "    kern = KernelWithNN(gpflow.kernels.Matern32(fX_dim), f)\n",
    "    \n",
    "    ## reset inducing (they live in a different space as X, so need to be careful with this)\n",
    "    ind = np.random.choice(Mnist.X.shape[0], minibatch_size, replace=False)\n",
    "    \n",
    "    # currently we need a hack due to model initialization.\n",
    "    feat = KernelSpaceInducingPoints(np.empty((M, fX_dim)))\n",
    "    #feat = FFeature(Z_0)  # ideally, we could move the calculation of Z_0\n",
    "    \n",
    "    # build the model \n",
    "\n",
    "    lik = gpflow.likelihoods.MultiClass(Mnist.Nclasses)\n",
    "\n",
    "    #Z = kmeans2(Mnist.X, M, minit='points')[0]\n",
    "\n",
    "    model = NN_SVGP(Mnist.X, Mnist.Y, kern, lik, feat=feat, num_latent=Mnist.Nclasses, minibatch_size=minibatch_size)\n",
    "\n",
    "    fZ = model.kern.compute_f(Mnist.X[ind])\n",
    "    # Z_0 = kmeans2(fZ, M)[0] might fail\n",
    "    Z_0 = fZ[np.random.choice(len(fZ), M, replace=False)]\n",
    "    model.feature.Z = Z_0\n",
    "\n",
    "    # use gpflow wrappers to train. NB all session handling is done for us\n",
    "    gpflow.training.AdamOptimizer(0.001).minimize(model, maxiter=ITERATIONS)\n",
    "\n",
    "    # predictions\n",
    "    m, v = model.predict_y(Mnist.Xtest)\n",
    "    preds = np.argmax(m, 1).reshape(Mnist.Ytest.shape)\n",
    "    correct = preds == Mnist.Ytest.astype(int)\n",
    "    acc = np.average(correct.astype(float)) * 100.\n",
    "\n",
    "    print('accuracy is {:.4f}%'.format(acc))\n",
    "\n",
    "gpflow.reset_default_graph_and_session()\n",
    "ex2b()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 2: a gpflow model on top of a tensorflow model\n",
    "Now we'll do things the other way: we'll take a model implemented in pure tensorflow, and show how we can put a gpflow model on the top"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T11:40:21.347439Z",
     "start_time": "2018-06-20T11:25:46.167693Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc is 98.5900\n"
     ]
    }
   ],
   "source": [
    "def ex2():\n",
    "    minibatch_size = notebook_niter(1000, test_n=10)\n",
    "    gp_dim = 5\n",
    "    M = notebook_niter(100, test_n=5)\n",
    "\n",
    "    ## placeholders\n",
    "    X = tf.placeholder(tf.float32, [minibatch_size, Mnist.input_dim])  # fixed shape so num_data works in SVGP\n",
    "    Y = tf.placeholder(tf.float32, [minibatch_size, 1])\n",
    "    Xtest = tf.placeholder(tf.float32, [None, Mnist.input_dim])\n",
    "\n",
    "    ## build graph\n",
    "\n",
    "    with tf.variable_scope('cnn'):\n",
    "        f_X = tf.cast(cnn_fn(X, gp_dim), dtype=float_type)\n",
    "\n",
    "    with tf.variable_scope('cnn', reuse=True):\n",
    "        f_Xtest = tf.cast(cnn_fn(Xtest, gp_dim), dtype=float_type)\n",
    "\n",
    "    gp_model = gpflow.models.SVGP(f_X, tf.cast(Y, dtype=float_type), \n",
    "                                  gpflow.kernels.RBF(gp_dim), gpflow.likelihoods.MultiClass(Mnist.Nclasses), \n",
    "                                  Z=np.zeros((M, gp_dim)), # we'll set this later\n",
    "                                  num_latent=Mnist.Nclasses)\n",
    "\n",
    "    loss = -gp_model.likelihood_tensor\n",
    "\n",
    "    m, v = gp_model._build_predict(f_Xtest)\n",
    "    my, yv = gp_model.likelihood.predict_mean_and_var(m, v)\n",
    "\n",
    "    with tf.variable_scope('adam'):\n",
    "        opt_step = tf.train.AdamOptimizer(0.001).minimize(loss)\n",
    "\n",
    "    tf_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='adam')\n",
    "    tf_vars += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='cnn')\n",
    "\n",
    "    ## initialize\n",
    "    sess = tf.Session()\n",
    "    sess.run(tf.variables_initializer(var_list=tf_vars))\n",
    "    gp_model.initialize(session=sess)\n",
    "    \n",
    "    ## reset inducing (they live in a different space as X, so need to be careful with this)\n",
    "    ind = np.random.choice(Mnist.X.shape[0], minibatch_size, replace=False)\n",
    "\n",
    "    fZ = sess.run(f_X, feed_dict={X:Mnist.X[ind]})\n",
    "    # Z_0 = kmeans2(fZ, M)[0] might fail\n",
    "    Z_0 = fZ[np.random.choice(len(fZ), M, replace=False)]\n",
    "\n",
    "    def set_gp_param(param, value):\n",
    "        sess.run(tf.assign(param.unconstrained_tensor, param.transform.backward(value)))\n",
    "\n",
    "    set_gp_param(gp_model.feature.Z, Z_0)\n",
    "\n",
    "    ## train\n",
    "    for i in range(ITERATIONS):\n",
    "        ind = np.random.choice(Mnist.X.shape[0], minibatch_size, replace=False)\n",
    "        sess.run(opt_step, feed_dict={X:Mnist.X[ind], Y:Mnist.Y[ind]})\n",
    "\n",
    "    ## predict\n",
    "    preds = np.argmax(sess.run(my, feed_dict={Xtest:Mnist.Xtest}), 1).reshape(Mnist.Ytest.shape)\n",
    "    correct = preds == Mnist.Ytest.astype(int)\n",
    "    acc = np.average(correct.astype(float)) * 100.\n",
    "    print('acc is {:.4f}'.format(acc))\n",
    "\n",
    "gpflow.reset_default_graph_and_session()\n",
    "ex2()\n",
    "\n",
    "gpflow.reset_default_graph_and_session()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
